
FP8 kv cache quantization #4563

Draft
CUHKSZzxy wants to merge 16 commits into InternLM:main from CUHKSZzxy:feat/fp8-kv-cache-quant

Conversation

@CUHKSZzxy (Collaborator) commented Apr 29, 2026

Summary

This PR adds FP8 KV-cache support to the PyTorch CUDA/Triton backend, with a concise public policy surface aligned with vLLM's normal (scalar-scale) FP8 behavior.

Supported policies:

  • fp8 / fp8_e4m3 -> scalar-scale torch.float8_e4m3fn KV cache
  • fp8_e5m2 -> scalar-scale torch.float8_e5m2 KV cache

The implementation intentionally does not expose per-attention-head / per-token-head FP8 KV-cache modes or dynamic calculate_kv_scales-style calibration. Those paths were removed from the public surface to keep this PR focused on the normal scalar-scale FP8 path.

What Changed

Public API / Config

  • Adds QuantPolicy.FP8 for E4M3 scalar-scale FP8 KV cache.
  • Adds QuantPolicy.FP8_E5M2 for E5M2 scalar-scale FP8 KV cache.
  • Updates CLI parsing (see the sketch after this list) so:
    • --quant-policy fp8 defaults to E4M3
    • --quant-policy fp8_e4m3 maps to E4M3
    • --quant-policy fp8_e5m2 maps to E5M2
  • Keeps unsupported/deprecated aliases out of the public CLI surface.
  • Keeps deprecated dynamic scale calculation and per-attention-head/per-token-head policies out of this PR.
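
As a rough illustration of the mapping above (the helper and table names here are hypothetical, not LMDeploy's actual code):

```python
import torch

# Hypothetical sketch of the policy-string -> FP8 dtype mapping.
_FP8_KV_POLICIES = {
    'fp8': torch.float8_e4m3fn,       # bare fp8 defaults to E4M3
    'fp8_e4m3': torch.float8_e4m3fn,  # explicit E4M3
    'fp8_e5m2': torch.float8_e5m2,    # explicit E5M2
}

def kv_cache_dtype(quant_policy: str) -> torch.dtype:
    """Resolve a --quant-policy string to its FP8 cache dtype."""
    try:
        return _FP8_KV_POLICIES[quant_policy]
    except KeyError:
        raise ValueError(f'unsupported quant policy: {quant_policy!r}')
```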

Cache Allocation

  • Stores normal FP8 KV cache payloads directly as torch.float8_e4m3fn or torch.float8_e5m2 (see the allocation sketch after this list).
  • Avoids allocating per-token/head FP8 scale metadata for normal FP8.
  • Keeps existing INT4/INT8/TurboQuant metadata behavior unchanged.
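
A minimal allocation sketch under these rules, assuming a paged layout of (num_blocks, block_size, num_heads, head_dim); the function and shapes are illustrative, not the actual cache engine:

```python
import torch

def alloc_fp8_kv_block(num_blocks, block_size, num_heads, head_dim,
                       cache_dtype=torch.float8_e4m3fn, device='cuda'):
    """Normal FP8 stores the payload directly in the FP8 dtype."""
    shape = (num_blocks, block_size, num_heads, head_dim)
    k_cache = torch.empty(shape, dtype=cache_dtype, device=device)
    v_cache = torch.empty(shape, dtype=cache_dtype, device=device)
    # Note: no k_scales_zeros / v_scales_zeros metadata tensors here.
    return k_cache, v_cache
```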

Attention Runtime

  • Adds scalar k_scale / v_scale buffers on PyTorch attention layers (see the sketch after this list).
  • Defaults scalar scales to 1.0.
  • Warns once when using normal E4M3 FP8 with default k_scale=v_scale=1.0.
  • Threads scalar scales through the PyTorch CUDA attention path for cache fill, flatten, prefill, and decode.
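
The scale buffers might look roughly like this sketch (class and attribute names are assumptions, not the PR's actual module):

```python
import warnings

import torch
from torch import nn

class AttentionWithFP8KV(nn.Module):
    """Illustrative only: scalar K/V dequant scales on an attention layer."""

    def __init__(self):
        super().__init__()
        # Scalar (per-layer) scales, defaulting to 1.0 as described above.
        self.register_buffer('k_scale', torch.tensor(1.0))
        self.register_buffer('v_scale', torch.tensor(1.0))
        self._warned_default_scales = False

    def _maybe_warn_default_scales(self):
        # Warn once if E4M3 FP8 runs with the default unit scales.
        if not self._warned_default_scales and \
                float(self.k_scale) == 1.0 and float(self.v_scale) == 1.0:
            warnings.warn('E4M3 FP8 KV cache is using default '
                          'k_scale=v_scale=1.0; accuracy may degrade.')
            self._warned_default_scales = True
```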

Kernels

  • Adds scalar-scale FP8 cache fill (see the round-trip sketch after this list):
    • quantizes as value / scale
    • clamps to the target FP8 dtype range
    • stores directly into FP8 cache tensors
  • Adds scalar-scale FP8 flatten:
    • reads FP8 cache
    • dequantizes as stored * scale
    • writes contiguous K/V states for prefill paths
  • Adds scalar-scale FP8 paged decode support:
    • applies k_scale in QK score computation
    • applies v_scale before the PV accumulation
    • avoids materializing full dequantized K/V tensors in decode
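
In plain PyTorch terms, the fill/flatten round trip is roughly the following (a reference-style sketch, not the Triton kernels themselves):

```python
import torch

def quant_fp8(x: torch.Tensor, scale: float,
              dtype=torch.float8_e4m3fn) -> torch.Tensor:
    """Cache fill: quantize as value / scale, clamped to the FP8 range."""
    info = torch.finfo(dtype)
    return (x / scale).clamp(info.min, info.max).to(dtype)

def dequant_fp8(x_q: torch.Tensor, scale: float) -> torch.Tensor:
    """Flatten: dequantize as stored * scale."""
    return x_q.to(torch.float32) * scale

k = torch.randn(4, 8)
k_roundtrip = dequant_fp8(quant_fp8(k, scale=0.05), scale=0.05)
```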

Tests

  • Adds CLI/config tests for FP8 policy parsing and dtype mapping.
  • Adds cache descriptor tests confirming normal FP8 does not allocate extra quant metadata.
  • Adds kernel tests for:
    • FP8 scalar cache fill
    • FP8 scalar cache flatten
    • FP8 scalar paged attention decode
    • E4M3 and E5M2 variants
  • Adds SM80/SM90-aware guards for Triton FP8 dtype support differences (see the guard sketch after this list).
  • Extends existing quant-policy pipeline tests for FP8 behavior while avoiding known unsupported architecture paths.
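
The SM-aware guards could be written along these lines (a pytest sketch; the exact capability thresholds are assumptions about Triton's FP8 dtype support, not copied from the PR):

```python
import pytest
import torch

def _compute_capability() -> int:
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor

# Assumption: E5M2 Triton kernels run on SM80+, while E4M3 needs a
# newer architecture; adjust thresholds to match the real guards.
requires_fp8_e5m2 = pytest.mark.skipif(
    not torch.cuda.is_available() or _compute_capability() < 80,
    reason='FP8 E5M2 Triton kernels require SM80+')
requires_fp8_e4m3 = pytest.mark.skipif(
    not torch.cuda.is_available() or _compute_capability() < 89,
    reason='FP8 E4M3 Triton kernels require SM89+')
```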

Scope

This PR focuses on the LMDeploy PyTorch CUDA/Triton attention path.

Out of scope:

  • TurboMind FP8 KV-cache support
  • dlinfer/Ascend FP8 KV-cache support
  • per-attention-head or per-token-head FP8 KV-cache quantization
  • dynamic runtime KV-scale calibration
  • dataset/llm-compressor calibration workflows
  • FlashMLA-specific FP8 behavior beyond existing guarded paths

Rationale

The previous experimental direction included per-token/head scale metadata and dynamic scale calculation. That made the implementation heavier and less aligned with the common normal FP8 KV-cache path used by vLLM.

This PR keeps the first upstreamable FP8 KV-cache feature smaller and easier to validate:

  • normal FP8 uses scalar layer-level scales,
  • no per-token/head scale metadata is allocated,
  • no per-step absmax scale calculation runs on the hot path,
  • cache fill/decode/flatten behavior is explicit and test-covered,
  • unsupported or deprecated modes are not exposed as public API.

Test Plan

```bash
python -m pytest -q tests/test_lmdeploy/test_fp8_kv_cache_policy.py
python -m pytest -q tests/pytorch/kernel/test_fill_kv_cache.py
python -m pytest -q tests/pytorch/kernel/test_flatten_kv_cache.py
python -m pytest -q tests/pytorch/kernel/test_paged_attention.py
python -m pytest -q tests/test_lmdeploy/test_quant_policy.py
```

CUHKSZzxy and others added 16 commits April 23, 2026 14:57

Adds FP8 KV cache quantization (QuantPolicy.FP8 = 16) using
torch.float8_e4m3fn with per-token symmetric scale (no zero point).

Key design:
- Reuses existing fill_kv_cache_blocked_fp8() with group_size=head_dim
  for per-token scale semantics in the fill path
- Dequant in flatten_kv_cache and paged_attention via x.to(f32)*scale
- Scale tensor shape [..., 1]: symmetric, no zero point
- No bit packing (head_dim unchanged, unlike INT4/TURBO_QUANT)

Also fixes pre-existing TestFillKVCacheBlockedFP8 test failures caused
by calling .max() on float8_e4m3fn tensors (cast to float32 first).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Avoid constructing a temporary cu_seqlen_q tensor in the FP8 cache-fill path by letting fill_kv_cache_blocked_fp8 consume the existing q_start_loc and q_seq_length metadata directly. The kernel keeps the old cumulative-seqlen mode for direct callers via a USE_CU_SEQLEN constexpr.

Move default paged-decode FP8 dequant scaling across the attention dot products. K scales are applied after QK, and V scales are applied to probabilities before PV, which preserves the per-token/head scale algebra while avoiding full K/V tile dequantization in the hot decode loop.

Add a focused FP8 paged-attention test that compares against a dequantized-FP8 reference, including a split-head-dim case, so the fused scale placement is covered without conflating it with expected quantization error.
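
The scale algebra this commit relies on can be checked with a toy snippet (scalar-scale case shown; full-precision tensors stand in for the FP8 payloads):

```python
import torch

q = torch.randn(1, 8)          # one decode query, head_dim = 8
kq = torch.randn(4, 8)         # stand-in for the quantized K payload
vq = torch.randn(4, 8)         # stand-in for the quantized V payload
k_scale, v_scale = 0.07, 0.09  # scalar dequant scales

# Reference: dequantize K/V up front, then attend.
ref = torch.softmax(q @ (kq * k_scale).T, dim=-1) @ (vq * v_scale)

# Fused: apply k_scale to QK scores and v_scale to the probabilities
# before PV, so no dequantized K/V tiles are ever materialized.
p = torch.softmax((q @ kq.T) * k_scale, dim=-1)
fused = (p * v_scale) @ vq

torch.testing.assert_close(ref, fused)
```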

Split normal FP8 KV cache from the dynamic per-token/head FP8 path. Normal fp8/fp8_e4m3 and fp8_e5m2 now use scalar K/V scales with FP8 cache tensors and no k_scales_zeros/v_scales_zeros metadata allocation, while fp8_per_token_head variants keep the existing per-token/head scale-cache behavior.

Thread scalar k_scale/v_scale through PyTorch attention dispatch, cache fill, flatten, and paged decode kernels so normal FP8 can quantize on cache write and apply scalar dequant in decode/prefill without materialized metadata tensors. Add optional one-shot calculate_kv_scales support and guard CUDA graph capture while scale calculation is pending, mirroring vLLM's eager first-pass behavior.

Add focused CLI/config/cache descriptor tests and scalar/per-token FP8 kernel reference coverage. Validation: py_compile on changed runtime/kernel/test files; pytest -q tests/test_lmdeploy/test_fp8_kv_cache_policy.py; git diff --check. CUDA kernel tests were not run because nvidia-smi cannot communicate with the driver in this environment.

Remove the deprecated-style dynamic KV scale calculation path and keep normal FP8 on the vLLM-aligned scalar-scale behavior with default scales.

Drop the experimental per-token/head FP8 policy and tests so the public surface only exposes fp8, fp8_e4m3, and fp8_e5m2.

Sadly, we had to remove some potentially useful features to keep this PR concise and solid.